{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EXQz7Vp8ehqb"
      },
      "source": [
        "# 🚀 Getting started\n",
        "\n",
        "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.sandbox.google.com/github/google-deepmind/optax/blob/main/docs/optax-101.ipynb)\n",
        "\n",
        "Optax is a simple optimization library for [JAX](https://jax.readthedocs.io/). The main object is the {py:class}`GradientTransformation <optax.GradientTransformation>`, which can be chained with other transformations to obtain the final update operation and the optimizer state. Optax also contains some simple loss functions and utilities to help you write the full optimization steps. This notebook walks you through a few examples of how to use Optax."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vEIU3POrGiE5"
      },
      "source": [
        "## Example: Fitting a Linear Model\n",
        "\n",
        "Begin by importing the necessary packages:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Jr7_e_ZJ_hky"
      },
      "outputs": [],
      "source": [
        "import jax.numpy as jnp\n",
        "import jax\n",
        "import optax\n",
        "import functools"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n7kMS9kyM8vM"
      },
      "source": [
        "In this example, we begin by setting up a simple linear model and a loss function. You can use any other library, such as [Haiku](https://github.com/deepmind/dm-haiku) or [Flax](https://github.com/google/flax), to construct your networks. Here, we keep it simple and write the model ourselves. The loss function (L2 loss) comes from Optax's {doc}`losses <api/losses>` via {py:class}`l2_loss <optax.l2_loss>`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0-8XwoQF_AO2"
      },
      "outputs": [],
      "source": [
        "@functools.partial(jax.vmap, in_axes=(None, 0))\n",
        "def network(params, x):\n",
        "  return jnp.dot(params, x)\n",
        "\n",
        "def compute_loss(params, x, y):\n",
        "  y_pred = network(params, x)\n",
        "  loss = jnp.mean(optax.l2_loss(y_pred, y))\n",
        "  return loss"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EZviuSmuNFsC"
      },
      "source": [
        "Here we generate data under a known linear model (with `target_params=0.5`):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "H-_pwBx6_keL"
      },
      "outputs": [],
      "source": [
        "key = jax.random.PRNGKey(42)\n",
        "target_params = 0.5\n",
        "\n",
        "# Generate some data.\n",
        "xs = jax.random.normal(key, (16, 2))\n",
        "ys = jnp.sum(xs * target_params, axis=-1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Td4Lp3qDNsL3"
      },
      "source": [
        "### Basic usage of Optax\n",
        "\n",
        "Optax contains implementations of {doc}`many popular optimizers <api/optimizers>` that can be used very simply. For example, the gradient transform for the Adam optimizer is available at {py:class}`optax.adam`. Let's start by creating the {py:class}`GradientTransformation <optax.GradientTransformation>` object for Adam and naming it `optimizer`. We then initialize the optimizer state by calling its `init` function on the `params` of the network."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rsLXLb5wBeY2"
      },
      "outputs": [],
      "source": [
        "start_learning_rate = 1e-1\n",
        "optimizer = optax.adam(start_learning_rate)\n",
        "\n",
        "# Initialize parameters of the model + optimizer.\n",
        "params = jnp.array([0.0, 0.0])\n",
        "opt_state = optimizer.init(params)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CpAvP1WSnsyM"
      },
      "source": [
        "Next we write the update loop. The {py:class}`GradientTransformation <optax.GradientTransformation>` object contains an `update` function that takes in the current optimizer state and gradients and returns the `updates` that need to be applied to the parameters: `updates, new_opt_state = optimizer.update(grads, opt_state)`.\n",
        "\n",
        "Optax comes with a few simple {doc}`update rules <api/apply_updates>` that apply the updates from the gradient transforms to the current parameters to return new ones: `new_params = optax.apply_updates(params, updates)`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TNkhz_nrB2lx"
      },
      "outputs": [],
      "source": [
        "# A simple update loop.\n",
        "for _ in range(1000):\n",
        "  grads = jax.grad(compute_loss)(params, xs, ys)\n",
        "  updates, opt_state = optimizer.update(grads, opt_state)\n",
        "  params = optax.apply_updates(params, updates)\n",
        "\n",
        "assert jnp.allclose(params, target_params), \\\n",
        "'Optimization should retrieve the target params used to generate the data.'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XXEz3j7wPZUH"
      },
      "source": [
        "### Custom optimizers\n",
        "\n",
        "Optax makes it easy to create custom optimizers by {py:class}`chain <optax.chain>`ing gradient transforms. For example, the code in the next cell builds an optimizer based on Adam with a decaying learning rate. Note that the updates are finally scaled by `-1`, so the effective scaling is `-learning_rate`. This is an important detail because {py:class}`apply_updates <optax.apply_updates>` is additive."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KQNI2P3YEEgP"
      },
      "outputs": [],
      "source": [
        "# Exponential decay of the learning rate.\n",
        "scheduler = optax.exponential_decay(\n",
        "    init_value=start_learning_rate,\n",
        "    transition_steps=1000,\n",
        "    decay_rate=0.99)\n",
        "\n",
        "# Combining gradient transforms using `optax.chain`.\n",
        "gradient_transform = optax.chain(\n",
        "    optax.clip_by_global_norm(1.0),  # Clip the gradient by the global norm.\n",
        "    optax.scale_by_adam(),  # Use the updates from adam.\n",
        "    optax.scale_by_schedule(scheduler),  # Use the learning rate from the scheduler.\n",
        "    # Scale updates by -1 since optax.apply_updates is additive and we want to descend on the loss.\n",
        "    optax.scale(-1.0)\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XGUrLKxAEO3j"
      },
      "outputs": [],
      "source": [
        "# Initialize parameters of the model + optimizer.\n",
        "params = jnp.array([0.0, 0.0])  # Recall target_params=0.5.\n",
        "opt_state = gradient_transform.init(params)\n",
        "\n",
        "# A simple update loop.\n",
        "for _ in range(1000):\n",
        "  grads = jax.grad(compute_loss)(params, xs, ys)\n",
        "  updates, opt_state = gradient_transform.update(grads, opt_state)\n",
        "  params = optax.apply_updates(params, updates)\n",
        "\n",
        "assert jnp.allclose(params, target_params), \\\n",
        "'Optimization should retrieve the target params used to generate the data.'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pIxKL7WsXFl8"
      },
      "source": [
        "### Advanced usage of Optax"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nCtNiVTsZVt2"
      },
      "source": [
        "#### Modifying hyperparameters of optimizers in a schedule\n",
        "\n",
        "In some scenarios, changing the hyperparameters (other than the learning rate) of an optimizer can be useful to ensure training reliability. We can do this easily by using {py:class}`inject_hyperparams <optax.inject_hyperparams>`. For example, this piece of code decays the `max_norm` of the {py:class}`clip_by_global_norm <optax.clip_by_global_norm>` gradient transform as training progresses:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NR9Flsj7ZdpC"
      },
      "outputs": [],
      "source": [
        "decaying_global_norm_tx = optax.inject_hyperparams(optax.clip_by_global_norm)(\n",
        "    max_norm=optax.linear_schedule(1.0, 0.0, transition_steps=99))\n",
        "\n",
        "opt_state = decaying_global_norm_tx.init(None)\n",
        "assert opt_state.hyperparams['max_norm'] == 1.0, 'Max norm should start at 1.0'\n",
        "\n",
        "for _ in range(100):\n",
        "  _, opt_state = decaying_global_norm_tx.update(None, opt_state)\n",
        "\n",
        "assert opt_state.hyperparams['max_norm'] == 0.0, 'Max norm should end at 0.0'"
      ]
    },
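    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Schedules are not the only way to change an injected hyperparameter: the value stored in `opt_state.hyperparams` can also be overwritten by hand between steps. The cell below is a minimal sketch of that pattern, assuming a plain `optax.sgd` inner optimizer and toy `demo_params` / `demo_grads` arrays that are made up for illustration and are not part of the example above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import jax.numpy as jnp\n",
        "import optax\n",
        "\n",
        "# Toy values, chosen only for illustration.\n",
        "demo_params = jnp.array([1.0, -1.0])\n",
        "demo_grads = jnp.array([0.5, 0.5])\n",
        "\n",
        "# Wrap SGD so that its learning rate is stored in the optimizer state.\n",
        "demo_tx = optax.inject_hyperparams(optax.sgd)(learning_rate=0.1)\n",
        "demo_state = demo_tx.init(demo_params)\n",
        "\n",
        "# The first step uses the initial learning rate.\n",
        "demo_updates, demo_state = demo_tx.update(demo_grads, demo_state)\n",
        "\n",
        "# Overwrite the stored learning rate; the next update uses the new value.\n",
        "demo_state.hyperparams['learning_rate'] = 0.01\n",
        "demo_updates_small, demo_state = demo_tx.update(demo_grads, demo_state)"
      ]
    },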
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tKcocLxEyYf2"
      },
      "source": [
        "## Example: Fitting a MLP\n",
        "\n",
        "Let's use Optax to fit a parametrized function. We will consider the problem of learning to identify when a value is odd or even.\n",
        "\n",
        "We will begin by creating a dataset that consists of batches of random 8-bit integers (given by their binary representation), with each value labelled as \"even\" or \"odd\" using one-hot encoding (i.e. `[1, 0]` means even and `[0, 1]` means odd)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Gg6zyMBqydty"
      },
      "outputs": [],
      "source": [
        "import optax\n",
        "import jax.numpy as jnp\n",
        "import jax\n",
        "import numpy as np\n",
        "\n",
        "BATCH_SIZE = 5\n",
        "NUM_TRAIN_STEPS = 1_000\n",
        "RAW_TRAINING_DATA = np.random.randint(256, size=(NUM_TRAIN_STEPS, BATCH_SIZE, 1))  # Upper bound is exclusive.\n",
        "\n",
        "TRAINING_DATA = np.unpackbits(RAW_TRAINING_DATA.astype(np.uint8), axis=-1)\n",
        "LABELS = jax.nn.one_hot(RAW_TRAINING_DATA % 2, 2).astype(jnp.float32).reshape(NUM_TRAIN_STEPS, BATCH_SIZE, 2)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nV79rjQK8tvC"
      },
      "source": [
        "We may now define a parametrized function using JAX. This will allow us to efficiently compute gradients.\n",
        "\n",
        "There are a number of libraries that provide common building blocks for parametrized functions (such as Flax and Haiku). For this case though, we shall implement our function from scratch.\n",
        "\n",
        "Our function will be an MLP (multi-layer perceptron) with a single hidden layer and a single output layer. We initialize all parameters using a standard Gaussian {math}`\\mathcal{N}(0,1)` distribution."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Syp9LJ338h9-"
      },
      "outputs": [],
      "source": [
        "initial_params = {\n",
        "    'hidden': jax.random.normal(shape=[8, 32], key=jax.random.PRNGKey(0)),\n",
        "    'output': jax.random.normal(shape=[32, 2], key=jax.random.PRNGKey(1)),\n",
        "}\n",
        "\n",
        "\n",
        "def net(x: jnp.ndarray, params: optax.Params) -> jnp.ndarray:\n",
        "  x = jnp.dot(x, params['hidden'])\n",
        "  x = jax.nn.relu(x)\n",
        "  x = jnp.dot(x, params['output'])\n",
        "  return x\n",
        "\n",
        "\n",
        "def loss(params: optax.Params, batch: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:\n",
        "  y_hat = net(batch, params)\n",
        "\n",
        "  # optax also provides a number of common loss functions.\n",
        "  loss_value = optax.sigmoid_binary_cross_entropy(y_hat, labels).sum(axis=-1)\n",
        "\n",
        "  return loss_value.mean()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2LVHrJyH9vDe"
      },
      "source": [
        "We will use {py:class}`optax.adam` to compute the parameter updates from their gradients on each optimizer step.\n",
        "\n",
        "Note that since Optax optimizers are implemented using pure functions, we also need to keep track of the optimizer state. For the Adam optimizer, this state contains the running estimates of the first and second moments of the gradients, along with a step count."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 791,
          "status": "ok",
          "timestamp": 1706783722580,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "JsbPBTF09FGY",
        "outputId": "8ed2b4d2-5d60-4f48-bba0-aeb8bf575cf2"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "step 0, loss: 5.830467224121094\n",
            "step 100, loss: 0.11273854970932007\n",
            "step 200, loss: 0.07734161615371704\n",
            "step 300, loss: 0.04497973248362541\n",
            "step 400, loss: 0.006239257287234068\n",
            "step 500, loss: 0.006817951332777739\n",
            "step 600, loss: 0.0020168141927570105\n",
            "step 700, loss: 0.010563415475189686\n",
            "step 800, loss: 0.0004556340572889894\n",
            "step 900, loss: 0.011806081980466843\n"
          ]
        }
      ],
      "source": [
        "def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:\n",
        "  opt_state = optimizer.init(params)\n",
        "\n",
        "  @jax.jit\n",
        "  def step(params, opt_state, batch, labels):\n",
        "    loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)\n",
        "    updates, opt_state = optimizer.update(grads, opt_state, params)\n",
        "    params = optax.apply_updates(params, updates)\n",
        "    return params, opt_state, loss_value\n",
        "\n",
        "  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):\n",
        "    params, opt_state, loss_value = step(params, opt_state, batch, labels)\n",
        "    if i % 100 == 0:\n",
        "      print(f'step {i}, loss: {loss_value}')\n",
        "\n",
        "  return params\n",
        "\n",
        "# Finally, we can fit our parametrized function using the Adam optimizer\n",
        "# provided by optax.\n",
        "optimizer = optax.adam(learning_rate=1e-2)\n",
        "params = fit(initial_params, optimizer)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kTaBLYL8_Ppz"
      },
      "source": [
        "We see that our loss appears to have converged, which should indicate that we have successfully found better parameters for our network."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qT_Uaei5Dv_3"
      },
      "source": [
        "### Weight Decay, Schedules and Clipping\n",
        "\n",
        "Many research models make use of techniques such as learning rate scheduling and gradient clipping. These may be achieved by chaining together gradient transformations such as {py:class}`optax.adam` and {py:class}`optax.clip`.\n",
        "\n",
        "In the following, we will use Adam with weight decay ({py:class}`optax.adamw`), a cosine learning rate schedule (with warmup) and also gradient clipping."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 257,
          "status": "ok",
          "timestamp": 1706783722955,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "SZegYQajDtLi",
        "outputId": "00d3b1f8-d44d-4ab8-c47a-dc557dd327c3"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "step 0, loss: 5.830467224121094\n",
            "step 100, loss: 0.0\n",
            "step 200, loss: 1.2290412909483864e-18\n",
            "step 300, loss: 0.0\n",
            "step 400, loss: 0.0\n",
            "step 500, loss: 0.0\n",
            "step 600, loss: 0.0\n",
            "step 700, loss: 0.0\n",
            "step 800, loss: 0.0\n",
            "step 900, loss: 0.0\n"
          ]
        }
      ],
      "source": [
        "schedule = optax.warmup_cosine_decay_schedule(\n",
        "  init_value=0.0,\n",
        "  peak_value=1.0,\n",
        "  warmup_steps=50,\n",
        "  decay_steps=1_000,\n",
        "  end_value=0.0,\n",
        ")\n",
        "\n",
        "optimizer = optax.chain(\n",
        "  optax.clip(1.0),\n",
        "  optax.adamw(learning_rate=schedule),\n",
        ")\n",
        "\n",
        "params = fit(initial_params, optimizer)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qf53Y6mT1Vwl"
      },
      "source": [
        "## Components\n",
        "\n",
        "We refer to the {doc}`docs <index>` for a detailed list of available Optax components. Here, we highlight the main categories of building blocks provided by Optax."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WZFpEKi82TGx"
      },
      "source": [
        "### Gradient Transformations ([transform.py](https://github.com/google-deepmind/optax/blob/main/optax/_src/transform.py))\n",
        "\n",
        "One of the key building blocks of Optax is a {py:class}`GradientTransformation <optax.GradientTransformation>`. Each transformation is defined by two functions:\n",
        "\n",
        "`state = init(params)`\n",
        "\n",
        "`grads, state = update(grads, state, params=None)`"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n6SsC9lNGiE-"
      },
      "source": [
        "The `init` function initializes a (possibly empty) set of statistics (aka state) and the `update` function transforms a candidate gradient given some statistics, and (optionally) the current value of the parameters.\n",
        "\n",
        "For example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_yCQbSCc2KhJ"
      },
      "outputs": [],
      "source": [
        "tx = optax.scale_by_rms()\n",
        "state = tx.init(params)  # init stats\n",
        "grads = jax.grad(loss)(params, TRAINING_DATA, LABELS)\n",
        "updates, state = tx.update(grads, state, params)  # transform & update stats."
      ]
    },
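    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the `init` / `update` contract concrete, here is a minimal sketch of a hand-written transformation. The name `scale_by_constant` and the toy `toy_params` / `toy_grads` values are hypothetical and used only for illustration; `optax.GradientTransformation` and `optax.EmptyState` are the only Optax pieces involved."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import optax\n",
        "\n",
        "\n",
        "def scale_by_constant(factor):\n",
        "  # Hypothetical stateless transformation: multiplies every update by `factor`.\n",
        "\n",
        "  def init_fn(params):\n",
        "    del params  # There are no statistics to track.\n",
        "    return optax.EmptyState()\n",
        "\n",
        "  def update_fn(updates, state, params=None):\n",
        "    del params  # This transformation does not look at the parameters.\n",
        "    updates = jax.tree_util.tree_map(lambda g: factor * g, updates)\n",
        "    return updates, state\n",
        "\n",
        "  return optax.GradientTransformation(init_fn, update_fn)\n",
        "\n",
        "\n",
        "# Toy values, chosen only for illustration.\n",
        "toy_params = {'w': jnp.ones(2)}\n",
        "toy_grads = {'w': jnp.array([1.0, -2.0])}\n",
        "\n",
        "toy_tx = scale_by_constant(0.5)\n",
        "toy_state = toy_tx.init(toy_params)\n",
        "toy_updates, toy_state = toy_tx.update(toy_grads, toy_state, toy_params)"
      ]
    },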
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TyxJmbBq2xT6"
      },
      "source": [
        "### Composing Gradient Transformations ([combine.py](https://github.com/google-deepmind/optax/blob/main/optax/_src/combine.py))\n",
        "\n",
        "Transformations take candidate gradients as input and return processed gradients as output, rather than returning the updated parameters. This design is what makes it possible to combine arbitrary transformations into a custom optimiser or gradient processor, and to combine transformations for different gradients that operate on a shared set of variables.\n",
        "\n",
        "For instance, {py:class}`chain <optax.chain>` combines them sequentially, and returns a new {py:class}`GradientTransformation <optax.GradientTransformation>` that applies several transformations in sequence.\n",
        "\n",
        "For example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TNPC9e7I28m8"
      },
      "outputs": [],
      "source": [
        "max_norm = 100.\n",
        "learning_rate = 1e-3\n",
        "\n",
        "my_optimiser = optax.chain(\n",
        "    optax.clip_by_global_norm(max_norm),\n",
        "    optax.scale_by_adam(eps=1e-4),\n",
        "    optax.scale(-learning_rate))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JmV92-PI2_pS"
      },
      "source": [
        "### Wrapping Gradient Transformations ([wrappers.py](https://github.com/google-deepmind/optax/blob/main/optax/_src/wrappers.py))\n",
        "\n",
        "Optax also provides several wrappers that take a {py:class}`GradientTransformation <optax.GradientTransformation>` as input and return a new {py:class}`GradientTransformation <optax.GradientTransformation>` that modifies the behaviour of the inner transformation in a specific way.\n",
        "\n",
        "For instance, the {py:class}`flatten <optax.flatten>` wrapper flattens gradients into a single large vector before applying the inner {py:class}`GradientTransformation <optax.GradientTransformation>`. The transformed updates are then unflattened before being returned to the user. This can be used to reduce the overhead of performing many calculations on lots of small variables, at the cost of increasing memory usage.\n",
        "\n",
        "For example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b1TlMbAk3Jbo"
      },
      "outputs": [],
      "source": [
        "my_optimiser = optax.flatten(optax.adam(learning_rate))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IUCIMymV3M2n"
      },
      "source": [
        "Other examples of wrappers include accumulating gradients over multiple steps or applying the inner transformation only to specific parameters or at specific steps."
      ]
    },
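    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough sketch of those two kinds of wrappers, the cell below uses `optax.masked` to transform only some parameters and `optax.MultiSteps` to accumulate gradients over two calls. The `wrap_params` / `wrap_grads` dictionaries and the choice of `optax.sgd(0.1)` as the inner optimizer are assumptions made for illustration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import jax.numpy as jnp\n",
        "import optax\n",
        "\n",
        "# Toy values, chosen only for illustration.\n",
        "wrap_params = {'w': jnp.ones(3), 'b': jnp.zeros(())}\n",
        "wrap_grads = {'w': jnp.full((3,), 2.0), 'b': jnp.array(2.0)}\n",
        "\n",
        "# Apply the inner transformation to 'w' only; 'b' is passed through unchanged.\n",
        "masked_tx = optax.masked(optax.scale(-0.1), {'w': True, 'b': False})\n",
        "masked_state = masked_tx.init(wrap_params)\n",
        "masked_updates, masked_state = masked_tx.update(\n",
        "    wrap_grads, masked_state, wrap_params)\n",
        "\n",
        "# Accumulate gradients over 2 calls before running the inner optimizer.\n",
        "accum_tx = optax.MultiSteps(optax.sgd(0.1), every_k_schedule=2)\n",
        "accum_state = accum_tx.init(wrap_params)\n",
        "# The first call only accumulates, so the emitted updates are zero.\n",
        "first_updates, accum_state = accum_tx.update(\n",
        "    wrap_grads, accum_state, wrap_params)\n",
        "# The second call passes the averaged gradients to the inner optimizer.\n",
        "second_updates, accum_state = accum_tx.update(\n",
        "    wrap_grads, accum_state, wrap_params)"
      ]
    },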
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AGAmqST33PkO"
      },
      "source": [
        "### Schedules ([schedule.py](https://github.com/google-deepmind/optax/blob/main/optax/_src/schedule.py))\n",
        "\n",
        "Many popular transformations use time-dependent components, e.g. to anneal a hyper-parameter such as the learning rate. For this purpose, Optax provides schedules that compute a scalar value as a function of a `step` count.\n",
        "\n",
        "For example, you may use a {py:class}`polynomial_schedule <optax.polynomial_schedule>` (with `power=1`) to decay a hyper-parameter linearly over a number of steps:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 131,
          "status": "ok",
          "timestamp": 1706783724487,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "Zbr61DLP3ecy",
        "outputId": "20beae28-32bf-4030-f877-fcd4a815a0c2"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "1.0\n",
            "0.8\n",
            "0.6\n",
            "0.39999998\n",
            "0.19999999\n",
            "0.0\n"
          ]
        }
      ],
      "source": [
        "schedule_fn = optax.polynomial_schedule(\n",
        "    init_value=1., end_value=0., power=1, transition_steps=5)\n",
        "\n",
        "for step_count in range(6):\n",
        "  print(schedule_fn(step_count))  # [1., 0.8, 0.6, 0.4, 0.2, 0.]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LGt0AzHF3fjR"
      },
      "source": [
        "Schedules can be combined with other transforms as follows."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "W9oCb0Kw3igG"
      },
      "outputs": [],
      "source": [
        "schedule_fn = optax.polynomial_schedule(\n",
        "    init_value=-learning_rate, end_value=0., power=1, transition_steps=5)\n",
        "optimiser = optax.chain(\n",
        "    optax.clip_by_global_norm(max_norm),\n",
        "    optax.scale_by_adam(eps=1e-4),\n",
        "    optax.scale_by_schedule(schedule_fn))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sDSXlRAN_B2F"
      },
      "source": [
        "Schedules can also be used in place of the `learning_rate` argument of a\n",
        "{py:class}`GradientTransformation <optax.GradientTransformation>` as follows:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zyvlGLDw_BKk"
      },
      "outputs": [],
      "source": [
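        "# `optax.adam` negates its learning rate, so flip the sign of the negative `schedule_fn`.\n",
        "optimiser = optax.adam(learning_rate=lambda step: -schedule_fn(step))"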
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cKHZrM203kx4"
      },
      "source": [
        "### Popular optimisers ([alias.py](https://github.com/google-deepmind/optax/blob/main/optax/_src/alias.py))\n",
        "\n",
        "In addition to the low-level building blocks, we also provide aliases for popular optimisers built using these components (e.g. RMSProp, Adam and AdamW). These are all still instances of a {py:class}`GradientTransformation <optax.GradientTransformation>`, and can therefore be further combined with any of the individual building blocks.\n",
        "\n",
        "For example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Czk49AQz3w1J"
      },
      "outputs": [],
      "source": [
        "def adamw(learning_rate, b1, b2, eps, weight_decay):\n",
        "  return optax.chain(\n",
        "      optax.scale_by_adam(b1=b1, b2=b2, eps=eps),\n",
        "      optax.add_decayed_weights(weight_decay),  # Decoupled weight decay.\n",
        "      optax.scale(-learning_rate))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "j0tD_jWC3zar"
      },
      "source": [
        "### Applying updates ([update.py](https://github.com/google-deepmind/optax/blob/main/optax/_src/update.py))\n",
        "\n",
        "After transforming an update using a {py:class}`GradientTransformation <optax.GradientTransformation>` or any custom manipulation of the update, you will typically apply the update to a set of parameters. This can be done trivially using `tree_map`.\n",
        "\n",
        "For convenience, we expose an {py:class}`apply_updates <optax.apply_updates>` function to apply updates to parameters. The function just adds the updates and the parameters together, i.e. `tree_map(lambda p, u: p + u, params, updates)`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YG-TNzYm4CHt"
      },
      "outputs": [],
      "source": [
        "updates, state = tx.update(grads, state, params)  # transform & update stats.\n",
        "new_params = optax.apply_updates(params, updates)  # update the parameters."
      ]
    },
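    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The equivalence between `apply_updates` and a plain `tree_map` can be checked directly. This is a small sketch with made-up `sketch_params` / `sketch_updates` values, using `jax.tree_util.tree_map` for the manual version."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import optax\n",
        "\n",
        "# Toy values, chosen only for illustration.\n",
        "sketch_params = {'w': jnp.array([1.0, 2.0]), 'b': jnp.array(0.5)}\n",
        "sketch_updates = {'w': jnp.array([-0.1, 0.1]), 'b': jnp.array(-0.5)}\n",
        "\n",
        "# Manual version: add each update to the matching parameter leaf.\n",
        "manual_params = jax.tree_util.tree_map(\n",
        "    lambda p, u: p + u, sketch_params, sketch_updates)\n",
        "# Convenience version provided by Optax.\n",
        "optax_params = optax.apply_updates(sketch_params, sketch_updates)\n",
        "\n",
        "# Both versions should agree on every leaf.\n",
        "same = jax.tree_util.tree_map(jnp.allclose, manual_params, optax_params)\n",
        "assert all(jax.tree_util.tree_leaves(same))"
      ]
    },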
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eg85y6_s4C2c"
      },
      "source": [
        "Note that separating gradient transformations from the parameter update is critical to support composing a sequence of transformations (e.g. {py:class}`chain <optax.chain>`), as well as combining multiple updates to the same parameters (e.g. in multi-task settings where different tasks need different sets of gradient transformations)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dJzW0Flw4FP5"
      },
      "source": [
        "### Losses ([loss.py](https://github.com/google-deepmind/optax/tree/main/optax/losses))\n",
        "\n",
        "Optax provides a number of standard losses used in deep learning, such as {py:class}`l2_loss <optax.l2_loss>`, {py:class}`softmax_cross_entropy <optax.softmax_cross_entropy>`, {py:class}`cosine_distance <optax.cosine_distance>`, etc."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8JCWgHhJ4PMc"
      },
      "outputs": [],
      "source": [
        "predictions = net(TRAINING_DATA, params)\n",
        "loss_values = optax.huber_loss(predictions, LABELS)  # Named to avoid shadowing the `loss` function."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gAlaEpgQ4QyD"
      },
      "source": [
        "The losses accept batches as inputs, but they perform no reduction across the batch dimension(s). The reduction is trivial to do in JAX, for example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "45svU6Qr4ThD"
      },
      "outputs": [],
      "source": [
        "avg_loss = jnp.mean(optax.huber_loss(predictions, LABELS))\n",
        "sum_loss = jnp.sum(optax.huber_loss(predictions, LABELS))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MepQR-Cr4VaB"
      },
      "source": [
        "### Second Order ([second_order.py](https://github.com/google-deepmind/optax/tree/main/optax/second_order))\n",
        "\n",
        "Computing the Hessian or Fisher information matrices for neural networks is typically intractable due to the quadratic memory requirements. Solving for the diagonals of these matrices is often a better solution. The library offers functions for computing these diagonals with sub-quadratic memory requirements."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fcJiCWSQ4gPP"
      },
      "source": [
        "### Stochastic gradient estimators ([stochastic_gradient_estimators.py](https://github.com/google-deepmind/optax/blob/main/optax/monte_carlo/stochastic_gradient_estimators.py))\n",
        "\n",
        "Stochastic gradient estimators compute Monte Carlo estimates of gradients of the expectation of a function under a distribution with respect to the distribution's parameters.\n",
        "\n",
        "Unbiased estimators, such as the score function estimator (REINFORCE), pathwise estimator (reparameterization trick) or measure valued estimator, are implemented: {py:class}`score_function_jacobians <optax.monte_carlo.score_function_jacobians>`, {py:class}`pathwise_jacobians <optax.monte_carlo.pathwise_jacobians>` and {py:class}`measure_valued_jacobians <optax.monte_carlo.measure_valued_jacobians>`. Their applicability (both in terms of functions and distributions) is discussed in their respective documentation.\n",
        "\n",
        "Stochastic gradient estimators can be combined with common control variates for variance reduction via {py:class}`control_variates_jacobians <optax.monte_carlo.control_variates_jacobians>`. For provided control variates see {py:class}`control_delta_method <optax.monte_carlo.control_delta_method>` and {py:class}`moving_avg_baseline <optax.monte_carlo.moving_avg_baseline>`.\n",
        "\n",
        "The result of a gradient estimator or {py:class}`control_variates_jacobians <optax.monte_carlo.control_variates_jacobians>` contains the Jacobians of the function with respect to the samples from the input distribution. These can then be used to update distributional parameters or to assess gradient variance.\n",
        "\n",
        "Example of how to use the {py:class}`pathwise_jacobians <optax.monte_carlo.pathwise_jacobians>` estimator:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NYJOV6Vv4uDb"
      },
      "outputs": [],
      "source": [
        "mean, log_scale, rng, num_samples = 0., 1., jax.random.PRNGKey(0), 100\n",
        "dist_params = [mean, log_scale]\n",
        "function = lambda x: jnp.sum(x)\n",
        "jacobians = optax.monte_carlo.pathwise_jacobians(\n",
        "      function, dist_params,\n",
        "      optax.multi_normal, rng, num_samples)\n",
        "\n",
        "mean_grads = jnp.mean(jacobians[0], axis=0)\n",
        "log_scale_grads = jnp.mean(jacobians[1], axis=0)\n",
        "grads = [mean_grads, log_scale_grads]\n",
        "optim = optax.adam(1e-3)\n",
        "optim_state = optim.init(grads)\n",
        "optim_update, optim_state = optim.update(grads, optim_state)\n",
        "updated_dist_params = optax.apply_updates(dist_params, optim_update)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x5mXAl_H4xCH"
      },
      "source": [
        "where `optim` is an Optax optimizer."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "last_runtime": {
        "build_target": "//learning/grp/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "name": "Optax 101",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
